/* Copyright 2014-2015. The Regents of the University of California.
 * All rights reserved. Use of this source code is governed by
 * a BSD-style license which can be found in the LICENSE file.
 *
 * Authors:
 * 2014 Frank Ong <frankong@berkeley.edu>
 * 2014-2015 Martin Uecker <uecker@eecs.berkeley.edu>
 */

#include <math.h>
#include <complex.h>
#include <assert.h>
#include <stdbool.h>

#include "misc/misc.h"
#include "misc/debug.h"

#include "num/multind.h"
#include "num/flpmath.h"
#include "num/filter.h"
#include "num/fft.h"
#include "num/shuffle.h"

#include "linops/linop.h"
#include "linops/someops.h"

#include "noncart/grid.h"

#include "nufft.h"

// Flag mask with bits 0, 1 and 2 set. In this file it is passed as the
// dimension selector to md_select_dims(), fftmod(), fftscale(), fftuc()
// and linop_fft_create().
#define FFT_FLAGS (MD_BIT(0)|MD_BIT(1)|MD_BIT(2))

// Default NUFFT configuration: the toeplitz option is switched off.
struct nufft_conf_s nufft_conf_defaults = {

	.toeplitz = false,
};


/**
 *
 * NUFFT internal data structure
 *
 * Filled in by nufft_create() and handed to linop_create() as the data
 * pointer that the apply / adjoint / normal / free callbacks receive.
 * Every dims / strs array below is allocated with N + 3 entries.
 *
 */
struct nufft_data {

	struct nufft_conf_s conf;	// copy of the configuration passed to nufft_create()

	bool use_gpu;			// as passed to nufft_create()

	unsigned int N;			// number of dimensions passed to nufft_create()

	const complex float* linphase;	// output of compute_linphases() after fftmod()/fftscale(); multiplied by roll when !conf.toeplitz
	const complex float* traj;	// trajectory pointer of the caller (stored as is, not copied)
	const complex float* roll;	// output of rolloff_correction() for img_dims
	const complex float* psf;	// output of compute_psf2() when conf.toeplitz, NULL otherwise
	const complex float* fftmod;	// array of ones (img_dims) passed through fftmod()
	const complex float* weights;	// private copy of the caller's weights, or NULL

	complex float* grid;		// work buffer of size cml_dims used by the apply functions

	float width;			// set to 3. in nufft_create(); passed to rolloff_correction(), grid2() and grid2H()
	double beta;			// calc_beta(2., width)

	const struct linop_s* fft_op;	// linop_fft_create() over cml_dims with FFT_FLAGS

	long* ksp_dims;			// caller's ksp_dims, remaining entries singleton
	long* cim_dims;			// caller's cim_dims, remaining entries singleton
	long* cml_dims;			// cim_dims with entry N replaced by lph_dims[N]
	long* img_dims;			// md_select_dims(FFT_FLAGS) of cim_dims
	long* trj_dims;			// caller's traj_dims, remaining entries singleton
	long* lph_dims;			// set by compute_linphases(); entry N is the number of linear phases
	long* psf_dims;			// only initialized when conf.toeplitz
	long* wgh_dims;			// only initialized when weights were given

	//!
	long* cm2_dims;			// cim_dims with the first three entries doubled (unless equal to 1)

	// strides computed with md_calc_strides(..., CFL_SIZE) from the dims above
	long* ksp_strs;
	long* cim_strs;
	long* cml_strs;
	long* img_strs;
	long* trj_strs;
	long* lph_strs;
	long* psf_strs;			// only initialized when conf.toeplitz
	long* wgh_strs;			// only initialized when weights were given
};


/* Operator callbacks registered with linop_create() in nufft_create(). */
static void nufft_free_data(const void* _data);
static void nufft_apply(const void* _data, complex float* dst, const complex float* src);
static void nufft_apply_adjoint(const void* _data, complex float* dst, const complex float* src);
static void nufft_apply_normal(const void* _data, complex float* dst, const complex float* src);


/* Internal helpers. The array bounds are spelled exactly as in the
 * definitions (N + 3 entries) so that prototype and definition agree. */
static void toeplitz_mult(const struct nufft_data* data, complex float* dst, const complex float* src);
static complex float* compute_linphases(unsigned int N, long lph_dims[N + 3], const long img_dims[N + 3]);
static complex float* compute_psf2(unsigned int N, const long psf_dims[N + 3], const long trj_dims[N + 3], const complex float* traj, const complex float* weights);


/**
 *
 * NUFFT operator initialization
 *
 * Allocates and fills a struct nufft_data and wraps it into a linear
 * operator together with the apply / adjoint / normal / free callbacks
 * of this file.
 *
 * @param N		-	number of dimensions
 * @param ksp_dims      -	kspace dimension
 * @param cim_dims	-	coil images dimension
 * @param traj_dims	-	trajectory dimension; traj_dims[0] must be 3 and
 *				traj_dims[1], traj_dims[2] must equal ksp_dims[1], ksp_dims[2]
 * @param traj		-	trajectory (the pointer is stored, the data is not copied)
 * @param weights	-	optional weights, may be NULL; copied into the operator
 * @param conf          -	configuration options
 * @param use_gpu       -	use gpu boolean
 *
 * @return operator returned by linop_create()
 *
 */
struct linop_s* nufft_create(unsigned int N, const long ksp_dims[N], const long cim_dims[N], const long traj_dims[N], const complex float* traj, const complex float* weights, struct nufft_conf_s conf, bool use_gpu)

{
	struct nufft_data* data = (struct nufft_data*)xmalloc(sizeof(struct nufft_data));

	data->N = N;
	data->use_gpu = use_gpu;
	data->traj = traj;	// borrowed: read again on every apply / adjoint call
	data->conf = conf;

	// parameters later handed to rolloff_correction(), grid2() and grid2H()
	data->width = 3.;
	data->beta = calc_beta(2., data->width);

	// get dims

	assert(md_check_compat(N - 3, 0, ksp_dims + 3, cim_dims + 3));

	// all internal dims / strides arrays carry three extra trailing entries
	unsigned int ND = N + 3;

	data->ksp_dims = xmalloc(ND * sizeof(long));
	data->cim_dims = xmalloc(ND * sizeof(long));
	data->cml_dims = xmalloc(ND * sizeof(long));
	data->img_dims = xmalloc(ND * sizeof(long));
	data->trj_dims = xmalloc(ND * sizeof(long));
	data->lph_dims = xmalloc(ND * sizeof(long));
	data->psf_dims = xmalloc(ND * sizeof(long));
	data->wgh_dims = xmalloc(ND * sizeof(long));

	data->ksp_strs = xmalloc(ND * sizeof(long));
	data->cim_strs = xmalloc(ND * sizeof(long));
	data->cml_strs = xmalloc(ND * sizeof(long));
	data->img_strs = xmalloc(ND * sizeof(long));
	data->trj_strs = xmalloc(ND * sizeof(long));
	data->lph_strs = xmalloc(ND * sizeof(long));
	data->psf_strs = xmalloc(ND * sizeof(long));
	data->wgh_strs = xmalloc(ND * sizeof(long));

	// initialize all ND entries first, then overwrite the first N with the caller's dims
	md_singleton_dims(ND, data->cim_dims);
	md_singleton_dims(ND, data->ksp_dims);

	md_copy_dims(N, data->cim_dims, cim_dims);
	md_copy_dims(N, data->ksp_dims, ksp_dims);


	// img_dims: selection of cim_dims by FFT_FLAGS
	md_select_dims(ND, FFT_FLAGS, data->img_dims, data->cim_dims);

	// the trajectory must have three components and match kspace in dims 1 and 2
	assert(3 == traj_dims[0]);
	assert(traj_dims[1] == ksp_dims[1]);
	assert(traj_dims[2] == ksp_dims[2]);
	assert(md_check_compat(N - 3, ~0, traj_dims + 3, ksp_dims + 3));
	assert(md_check_bounds(N - 3, ~0, traj_dims + 3, ksp_dims + 3));

	md_singleton_dims(ND, data->trj_dims);
	md_copy_dims(N, data->trj_dims, traj_dims);


	// get strides

	md_calc_strides(ND, data->cim_strs, data->cim_dims, CFL_SIZE);
	md_calc_strides(ND, data->img_strs, data->img_dims, CFL_SIZE);
	md_calc_strides(ND, data->trj_strs, data->trj_dims, CFL_SIZE);
	md_calc_strides(ND, data->ksp_strs, data->ksp_dims, CFL_SIZE);


	data->weights = NULL;

	if (NULL != weights) {

		// wgh_dims: trj_dims with dimension 0 deselected; keep a private copy of the weights
		md_singleton_dims(ND, data->wgh_dims);
		md_select_dims(N, ~MD_BIT(0), data->wgh_dims, data->trj_dims);
		md_calc_strides(ND, data->wgh_strs, data->wgh_dims, CFL_SIZE);

		complex float* tmp = md_alloc(ND, data->wgh_dims, CFL_SIZE);
		md_copy(ND, data->wgh_dims, tmp, weights, CFL_SIZE);
		data->weights = tmp;
	}


	// roll: filled by rolloff_correction() for img_dims
	complex float* roll = md_alloc(ND, data->img_dims, CFL_SIZE);
	rolloff_correction(2., data->width, data->beta, data->img_dims, roll);
	data->roll = roll;


	// compute_linphases() also fills lph_dims (entry N = number of linear phases)
	complex float* linphase = compute_linphases(N, data->lph_dims, data->img_dims);

	md_calc_strides(ND, data->lph_strs, data->lph_dims, CFL_SIZE);

	// without toeplitz, roll is folded into linphase here;
	// with toeplitz, nufft_apply_adjoint() multiplies by roll separately
	if (!conf.toeplitz)
		md_zmul2(ND, data->lph_dims, data->lph_strs, linphase, data->lph_strs, linphase, data->img_strs, data->roll);


	fftmod(ND, data->lph_dims, FFT_FLAGS, linphase, linphase);
	fftscale(ND, data->lph_dims, FFT_FLAGS, linphase, linphase);
//	md_zsmul(ND, data->lph_dims, linphase, linphase, 1. / (float)(data->trj_dims[1] * data->trj_dims[2]));

	// fftmod buffer: fftmod() applied to an array of ones
	complex float* fftm = md_alloc(ND, data->img_dims, CFL_SIZE);
	md_zfill(ND, data->img_dims, fftm, 1.);
	fftmod(ND, data->img_dims, FFT_FLAGS, fftm, fftm);
	data->fftmod = fftm;



	data->linphase = linphase;
	data->psf = NULL;

	if (conf.toeplitz) {

		// psf_dims: first three entries and entry N from lph_dims, the rest from trj_dims
#if 0
		md_copy_dims(ND, data->psf_dims, data->lph_dims);
#else
		md_copy_dims(3, data->psf_dims, data->lph_dims);
		md_copy_dims(ND - 3, data->psf_dims + 3, data->trj_dims + 3);
		data->psf_dims[N] = data->lph_dims[N];
#endif
		md_calc_strides(ND, data->psf_strs, data->psf_dims, CFL_SIZE);
		data->psf = compute_psf2(N, data->psf_dims, data->trj_dims, data->traj, data->weights);
	}


	// cml_dims: cim_dims with the linear-phase count in entry N
	md_copy_dims(ND, data->cml_dims, data->cim_dims);
	data->cml_dims[N + 0] = data->lph_dims[N + 0];

	md_calc_strides(ND, data->cml_strs, data->cml_dims, CFL_SIZE);


	data->cm2_dims = xmalloc(ND * sizeof(long));
	// !
	// cm2_dims: cim_dims with the first three entries doubled (entries equal to 1 stay 1)
	md_copy_dims(ND, data->cm2_dims, data->cim_dims);
	for (int i = 0; i < 3; i++)
		data->cm2_dims[i] = (1 == cim_dims[i]) ? 1 : (2 * cim_dims[i]);



	// work buffer shared by the apply functions
	data->grid = md_alloc(ND, data->cml_dims, CFL_SIZE);

	data->fft_op = linop_fft_create(ND, data->cml_dims, FFT_FLAGS, use_gpu);



	return linop_create(N, ksp_dims, N, cim_dims,
		data, nufft_apply, nufft_apply_adjoint, nufft_apply_normal, NULL, nufft_free_data);
}




/**
 * Allocate and fill the stack of linear phases.
 *
 * All 8 combinations of a shift of 0 or -0.5 along the first three
 * dimensions are enumerated; a combination is dropped if it shifts along
 * a dimension of size 1. The remaining count is written to lph_dims[N].
 *
 * @param N		number of dimensions (arrays have N + 3 entries)
 * @param lph_dims	output: FFT_FLAGS selection of img_dims, entry N = number of phases
 * @param img_dims	image dimensions
 * @return buffer allocated with md_alloc() for lph_dims
 */
static complex float* compute_linphases(unsigned int N, long lph_dims[N + 3], const long img_dims[N + 3])
{
	float shifts[8][3];
	int count = 0;

	for (int mask = 0; mask < 8; mask++) {

		bool usable = true;

		for (int d = 0; d < 3; d++) {

			if (MD_IS_SET(mask, d)) {

				shifts[count][d] = -0.5;

				// a singleton dimension cannot be shifted
				if (1 == img_dims[d])
					usable = false;

			} else {

				shifts[count][d] = 0.;
			}
		}

		// an unusable row is simply overwritten by the next combination
		if (usable)
			count++;
	}

	unsigned int ND = N + 3;

	md_select_dims(ND, FFT_FLAGS, lph_dims, img_dims);
	lph_dims[N + 0] = count;

	complex float* linphase = md_alloc(ND, lph_dims, CFL_SIZE);

	// number of elements of one phase
	const long img_size = md_calc_size(ND, img_dims);

	for (int k = 0; k < count; k++) {

		float shifts2[ND];

		for (unsigned int d = 0; d < ND; d++)
			shifts2[d] = (d < 3) ? shifts[k][d] : 0.f;

		linear_phase(ND, img_dims, shifts2, linphase + k * img_size);
	}

	return linphase;
}


/**
 * Apply the adjoint of a temporary NUFFT operator (default configuration,
 * no weights, no GPU) to a constant kspace input.
 *
 * The input is filled with ones and, if weights are given, multiplied by
 * the weights with md_zmul() and once more with md_zmulc().
 *
 * @param N		number of dimensions
 * @param img2_dims	dimensions of the returned array
 * @param trj_dims	trajectory dimensions
 * @param traj		trajectory
 * @param weights	optional weights, may be NULL
 * @return buffer allocated with md_alloc() for img2_dims; the caller frees it
 */
complex float* compute_psf(unsigned int N, const long img2_dims[N], const long trj_dims[N], const complex float* traj, const complex float* weights)
{
	// kspace dims: trajectory dims with dimension 0 deselected
	long ksp_dims1[N];
	md_select_dims(N, ~MD_BIT(0), ksp_dims1, trj_dims);

	complex float* ksp_in = md_alloc(N, ksp_dims1, CFL_SIZE);
	md_zfill(N, ksp_dims1, ksp_in, 1.);

	if (NULL != weights) {

		md_zmul(N, ksp_dims1, ksp_in, ksp_in, weights);
		md_zmulc(N, ksp_dims1, ksp_in, ksp_in, weights);
	}

	struct linop_s* nufft_op = nufft_create(N, ksp_dims1, img2_dims, trj_dims, traj, NULL,
						nufft_conf_defaults, false);

	complex float* psft = md_alloc(N, img2_dims, CFL_SIZE);

	linop_adjoint_unchecked(nufft_op, psft, ksp_in);

	linop_free(nufft_op);
	md_free(ksp_in);

	return psft;
}


/**
 * Compute the PSF used by toeplitz_mult().
 *
 * compute_psf() is evaluated on an image grid whose first three
 * dimensions are doubled (entries equal to 1 stay 1), using the
 * trajectory scaled by 2. The result is passed through fftuc() and then
 * split with md_decompose() into the psf_dims layout.
 *
 * @param N		number of dimensions (arrays have N + 3 entries)
 * @param psf_dims	dimensions of the returned PSF
 * @param trj_dims	trajectory dimensions
 * @param traj		trajectory
 * @param weights	optional weights, may be NULL
 * @return buffer allocated with md_alloc() for psf_dims
 */
static complex float* compute_psf2(unsigned int N, const long psf_dims[N + 3], const long trj_dims[N + 3], const complex float* traj, const complex float* weights)
{
	unsigned int ND = N + 3;

	// image dims: psf_dims without entry N
	long img_dims[ND];
	md_select_dims(ND, ~MD_BIT(N + 0), img_dims, psf_dims);

	// PSF 2x size

	long img2_dims[ND];
	md_copy_dims(ND, img2_dims, img_dims);

	for (int i = 0; i < 3; i++)
		img2_dims[i] = (1 == img_dims[i]) ? 1 : (2 * img_dims[i]);

	// trajectory scaled by 2 for the doubled grid; only needed for compute_psf()
	complex float* traj2 = md_alloc(ND, trj_dims, CFL_SIZE);
	md_zsmul(ND, trj_dims, traj2, traj, 2.);

	complex float* psft = compute_psf(ND, img2_dims, trj_dims, traj2, weights);
	md_free(traj2);

	fftuc(ND, img2_dims, FFT_FLAGS, psft, psft);

	// reformat

	complex float* psf = md_alloc(ND, psf_dims, CFL_SIZE);

	long factors[N];

	for (unsigned int i = 0; i < N; i++)
		factors[i] = ((img_dims[i] > 1) && (i < 3)) ? 2 : 1;

	md_decompose(N + 0, factors, psf_dims, psf, img2_dims, psft, CFL_SIZE);

	md_free(psft);
	return psf;
}





/**
 * Free nufft operator
 *
 * Releases everything nufft_create() allocated for the operator: all
 * dims / strides arrays, the work buffer, the precomputed arrays, the
 * FFT operator and the data structure itself. The trajectory is not
 * freed because nufft_create() only stores the caller's pointer.
 *
 * @param _data		pointer to the struct nufft_data of the operator
 */
static void nufft_free_data(const void* _data)
{
	struct nufft_data* data = (struct nufft_data*)_data;

	free(data->ksp_dims);
	free(data->cim_dims);
	free(data->cml_dims);
	free(data->img_dims);
	free(data->trj_dims);
	free(data->lph_dims);
	free(data->psf_dims);
	free(data->wgh_dims);
	free(data->cm2_dims);	// allocated separately in nufft_create()

	free(data->ksp_strs);
	free(data->cim_strs);
	free(data->cml_strs);
	free(data->img_strs);
	free(data->trj_strs);
	free(data->lph_strs);
	free(data->psf_strs);
	free(data->wgh_strs);

	md_free(data->grid);
	md_free((void*)data->linphase);
	md_free((void*)data->psf);	// NULL unless conf.toeplitz
	md_free((void*)data->fftmod);
	md_free((void*)data->weights);	// NULL if no weights were given
	md_free((void*)data->roll);	// always allocated in nufft_create()

	linop_free(data->fft_op);

	free(data);
}




// Forward: from image to kspace
static void nufft_apply(const void* _data, complex float* dst, const complex float* src)
{
	struct nufft_data* data = (struct nufft_data*)_data;

	assert(!data->conf.toeplitz); // if toeplitz linphase has no roll, so would need to be added

	unsigned int ND = data->N + 3;

	md_zmul2(ND, data->cml_dims, data->cml_strs, data->grid, data->cim_strs, src, data->lph_strs, data->linphase);
	linop_forward(data->fft_op, ND, data->cml_dims, data->grid, ND, data->cml_dims, data->grid);
	md_zmul2(ND, data->cml_dims, data->cml_strs, data->grid, data->cml_strs, data->grid, data->img_strs, data->fftmod);

	md_clear(ND, data->ksp_dims, dst, CFL_SIZE);

	complex float* gridX = md_alloc(data->N, data->cm2_dims, CFL_SIZE);

	long factors[data->N];

	for (unsigned int i = 0; i < data->N; i++)
		factors[i] = ((data->img_dims[i] > 1) && (i < 3)) ? 2 : 1;

	md_recompose(data->N, factors, data->cm2_dims, gridX, data->cml_dims, data->grid, CFL_SIZE);

	grid2H(2., data->width, data->beta, ND, data->trj_dims, data->traj, data->ksp_dims, dst, data->cm2_dims, gridX);

	md_free(gridX);

	if (NULL != data->weights)
		md_zmul2(data->N, data->ksp_dims, data->ksp_strs, dst, data->ksp_strs, dst, data->wgh_strs, data->weights);
}


/**
 * Adjoint: from kspace to image
 *
 * @param _data		pointer to the struct nufft_data of the operator
 * @param dst		image output (cim_dims); cleared before accumulation
 * @param src		kspace input (ksp_dims); not modified
 */
static void nufft_apply_adjoint(const void* _data, complex float* dst, const complex float* src)
{
	const struct nufft_data* data = _data;

	const unsigned int N = data->N;
	const unsigned int ND = N + 3;

	long factors[N];

	for (unsigned int d = 0; d < N; d++)
		factors[d] = ((d < 3) && (data->img_dims[d] > 1)) ? 2 : 1;

	complex float* merged = md_alloc(N, data->cm2_dims, CFL_SIZE);
	md_clear(N, data->cm2_dims, merged, CFL_SIZE);

	// with weights, grid a weighted copy of the input instead of src itself
	complex float* weighted = NULL;
	const complex float* ksp = src;

	if (NULL != data->weights) {

		weighted = md_alloc(N, data->ksp_dims, CFL_SIZE);
		md_zmulc2(N, data->ksp_dims, data->ksp_strs, weighted, data->ksp_strs, src, data->wgh_strs, data->weights);
		ksp = weighted;
	}

	grid2(2., data->width, data->beta, ND, data->trj_dims, data->traj, data->cm2_dims, merged, data->ksp_dims, ksp);

	md_free(weighted);

	// split the cm2_dims array into the stack held in grid
	md_decompose(N, factors, data->cml_dims, data->grid, data->cm2_dims, merged, CFL_SIZE);
	md_free(merged);

	md_zmulc2(ND, data->cml_dims, data->cml_strs, data->grid, data->cml_strs, data->grid, data->img_strs, data->fftmod);
	linop_adjoint(data->fft_op, ND, data->cml_dims, data->grid, ND, data->cml_dims, data->grid);

	// accumulate the stack into dst using linphase
	md_clear(ND, data->cim_dims, dst, CFL_SIZE);
	md_zfmacc2(ND, data->cml_dims, data->cim_strs, dst, data->cml_strs, data->grid, data->lph_strs, data->linphase);

	// with toeplitz, roll is not part of linphase and is applied here
	if (data->conf.toeplitz)
		md_zmul2(ND, data->cim_dims, data->cim_strs, dst, data->cim_strs, dst, data->img_strs, data->roll);
}



/**
 * Normal operator: forward followed by adjoint
 *
 * With conf.toeplitz the work is delegated to toeplitz_mult(). Otherwise
 * nufft_apply() and nufft_apply_adjoint() are chained through a temporary
 * kspace buffer.
 *
 * @param _data		pointer to the struct nufft_data of the operator
 * @param dst		image output
 * @param src		image input
 */
static void nufft_apply_normal(const void* _data, complex float* dst, const complex float* src)
{
	const struct nufft_data* data = _data;

	if (data->conf.toeplitz) {

		toeplitz_mult(data, dst, src);
		return;
	}

	complex float* ksp_buf = md_alloc(data->N + 3, data->ksp_dims, CFL_SIZE);

	nufft_apply(data, ksp_buf, src);
	nufft_apply_adjoint(data, dst, ksp_buf);

	md_free(ksp_buf);
}



/**
 * Normal operator for the toeplitz configuration.
 *
 * Multiplies src with linphase into the work buffer, applies the forward
 * FFT, multiplies with the precomputed psf, applies the adjoint FFT and
 * accumulates the result into dst using linphase.
 *
 * @param data		operator data; data->psf must have been computed
 * @param dst		image output (cim_dims); cleared before accumulation
 * @param src		image input (cim_dims)
 */
static void toeplitz_mult(const struct nufft_data* data, complex float* dst, const complex float* src)
{
	const unsigned int ND = data->N + 3;

	complex float* work = data->grid;
	const long* cml_dims = data->cml_dims;
	const long* cml_strs = data->cml_strs;

	md_zmul2(ND, cml_dims, cml_strs, work, data->cim_strs, src, data->lph_strs, data->linphase);

	linop_forward(data->fft_op, ND, cml_dims, work, ND, cml_dims, work);
	md_zmul2(ND, cml_dims, cml_strs, work, cml_strs, work, data->psf_strs, data->psf);
	linop_adjoint(data->fft_op, ND, cml_dims, work, ND, cml_dims, work);

	md_clear(ND, data->cim_dims, dst, CFL_SIZE);
	md_zfmacc2(ND, cml_dims, data->cim_strs, dst, cml_strs, work, data->lph_strs, data->linphase);
}









/**
 * Estimate image dimensions from trajectory
 *
 * For each of the first three trajectory components the largest magnitude
 * over all samples is determined. A component that is zero everywhere
 * yields a dimension of 1, otherwise an even size derived from that
 * maximum.
 *
 * @param N		number of trajectory dimensions
 * @param dims		output: estimated image dimensions
 * @param tdims		trajectory dimensions; tdims[0] is the number of
 *			components stored per sample (the first three are read)
 * @param traj		trajectory
 */
void estimate_im_dims(unsigned int N, long dims[3], const long tdims[N], const complex float* traj)
{
	float max_dims[3] = { 0., 0., 0. };

	// number of samples; loop-invariant, so it is evaluated only once
	const long samples = md_calc_size(N - 1, tdims + 1);

	for (long i = 0; i < samples; i++)
		for (int j = 0; j < 3; j++)
			max_dims[j] = MAX(cabsf(traj[j + tdims[0] * i]), max_dims[j]);

	for (int j = 0; j < 3; j++)
		dims[j] = (0. == max_dims[j]) ? 1 : (2 * (long)((2. * max_dims[j] + 1.5) / 2.));
}



